
Project outline

In this project, we analyse determinants of song popularity from a dataset on Spotify tracks.

In particular, our original dataset covers 114,000 tracks. Each track has 21 associated features, ranging from artist name, popularity, duration, and genre to audio characteristics such as ‘acousticness’ and tempo. All features that cannot be measured directly, such as ‘acousticness’, ‘danceability’, and ‘instrumentalness’, have been normalised to a 0–1 scale.

We feel it would be interesting to see what factors affect ‘popularity’, and believe it is likely to be determined by the other regressors in the dataset, such as ‘energy’, ‘danceability’, and ‘valence’. This could produce valuable models that predict which songs people will enjoy before they become popular, based on the characteristic or ‘intrinsic’ qualities of the song rather than the artist names attached to it. Hence this can help with ‘song recommendation’ features.

It is reasonable to assume each track is independent of the others, given that songs are usually written around new concepts. We can also assume they share the same probability distribution, since all songs are judged on the same criteria, all of which are normalised to the same scale. Hence it is reasonable to assume they are identically distributed.

Dataset source: https://www.kaggle.com/datasets/maharshipandya/-spotify-tracks-dataset

[JENNY] Data cleaning

Notes:

  • track_id is unique
  • track_name is not unique (keep different versions by different singers; remove duplicate versions by the same singer)
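The deduplication rule above can be sketched in base R. This is a toy illustration with made-up rows, not the actual cleaning that produced unique_tracks_genres.csv, and `raw_tracks`/`unique_tracks` are hypothetical names:

```r
# Toy illustration of the deduplication rule: keep one row per
# (track_name, artists) pair, so covers by different artists survive
# while duplicate versions by the same artist are dropped.
raw_tracks <- data.frame(
  track_name = c("Hello", "Hello", "Hello"),
  artists    = c("Artist A", "Artist A", "Artist B"),
  popularity = c(50, 60, 40)
)

# duplicated() works row-wise on the selected key columns;
# the second "Artist A" version is dropped, the cover by "Artist B" is kept
unique_tracks <- raw_tracks[!duplicated(raw_tracks[, c("track_name", "artists")]), ]
```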

Load libraries and data

library(tidyverse) 
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.0      ✔ purrr   0.3.4 
## ✔ tibble  3.1.8      ✔ dplyr   1.0.10
## ✔ tidyr   1.2.1      ✔ stringr 1.4.1 
## ✔ readr   2.1.2      ✔ forcats 0.5.2
## Warning: package 'ggplot2' was built under R version 4.2.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(GGally) # ggpairs()
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2
library(corrplot) # corrplot()
## Warning: package 'corrplot' was built under R version 4.2.2
## corrplot 0.92 loaded
library(gridExtra) # grid.arrange()
## 
## Attaching package: 'gridExtra'
## 
## The following object is masked from 'package:dplyr':
## 
##     combine
library(ggplot2) # ggplot() (already attached via tidyverse)
library(tidymodels) # initial_split()
## Warning: package 'tidymodels' was built under R version 4.2.2
## ── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
## ✔ broom        1.0.1     ✔ rsample      1.1.0
## ✔ dials        1.1.0     ✔ tune         1.0.1
## ✔ infer        1.0.3     ✔ workflows    1.1.2
## ✔ modeldata    1.0.1     ✔ workflowsets 1.0.0
## ✔ parsnip      1.0.3     ✔ yardstick    1.1.0
## ✔ recipes      1.0.3
## Warning: package 'dials' was built under R version 4.2.2
## Warning: package 'infer' was built under R version 4.2.2
## Warning: package 'modeldata' was built under R version 4.2.2
## Warning: package 'parsnip' was built under R version 4.2.2
## Warning: package 'recipes' was built under R version 4.2.2
## Warning: package 'rsample' was built under R version 4.2.2
## Warning: package 'tune' was built under R version 4.2.2
## Warning: package 'workflows' was built under R version 4.2.2
## Warning: package 'workflowsets' was built under R version 4.2.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ gridExtra::combine() masks dplyr::combine()
## ✖ scales::discard()    masks purrr::discard()
## ✖ dplyr::filter()      masks stats::filter()
## ✖ recipes::fixed()     masks stringr::fixed()
## ✖ dplyr::lag()         masks stats::lag()
## ✖ yardstick::spec()    masks readr::spec()
## ✖ recipes::step()      masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
library(glmnet) # glmnet()
## Warning: package 'glmnet' was built under R version 4.2.2
## Loading required package: Matrix
## Warning: package 'Matrix' was built under R version 4.2.2
## 
## Attaching package: 'Matrix'
## 
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## 
## Loaded glmnet 4.1-4
data_new <- read.csv("unique_tracks_genres.csv")
final_data <- select(data_new, -X, -track_id, -artists, -album_name, -track_name, -track_genre)

# convert into factors
final_data$explicit <- as.numeric(as.factor(final_data$explicit))-1 # 0 for FALSE, 1 for TRUE
final_data$mode <- as.integer(final_data$mode) # 0 for minor, 1 for major

# create new dummy showing if track has >1 genre
genres <- c("pop", "rock", "country", "jazz", "electronic", "classical", "world", "kids", "other", "rap")
final_data$two_genre <- as.numeric(rowSums(final_data[, genres]) == 2)
final_data$three_genre <- as.numeric(rowSums(final_data[, genres]) == 3)
final_data$four_genre <- as.numeric(rowSums(final_data[, genres]) == 4)
final_data$five_genre <- as.numeric(rowSums(final_data[, genres]) == 5)
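A quick sanity check of the multi-genre dummy logic, using toy rows rather than the real data: a track flagged in both ‘pop’ and ‘rock’ should get two_genre = 1, while a single-genre track should not.

```r
# Toy check of the multi-genre dummy construction used above
genres <- c("pop", "rock", "country", "jazz", "electronic",
            "classical", "world", "kids", "other", "rap")

toy <- data.frame(pop = c(1, 1), rock = c(1, 0), country = 0, jazz = 0,
                  electronic = 0, classical = 0, world = 0, kids = 0,
                  other = 0, rap = 0)

# rowSums counts how many genre dummies fire per track
toy$two_genre <- as.numeric(rowSums(toy[, genres]) == 2)
```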

Exploratory analysis

Summary and visualisation

summary(final_data)
##    popularity      duration_ms         explicit        danceability   
##  Min.   :  0.00   Min.   :      0   Min.   :0.00000   Min.   :0.0000  
##  1st Qu.: 21.00   1st Qu.: 173871   1st Qu.:0.00000   1st Qu.:0.4460  
##  Median : 35.00   Median : 215213   Median :0.00000   Median :0.5730  
##  Mean   : 34.76   Mean   : 231419   Mean   :0.08544   Mean   :0.5593  
##  3rd Qu.: 49.00   3rd Qu.: 267344   3rd Qu.:0.00000   3rd Qu.:0.6900  
##  Max.   :100.00   Max.   :5237295   Max.   :1.00000   Max.   :0.9850  
##      energy           key            loudness            mode       
##  Min.   :0.000   Min.   : 0.000   Min.   :-49.531   Min.   :0.0000  
##  1st Qu.:0.455   1st Qu.: 2.000   1st Qu.:-10.456   1st Qu.:0.0000  
##  Median :0.678   Median : 5.000   Median : -7.263   Median :1.0000  
##  Mean   :0.635   Mean   : 5.285   Mean   : -8.596   Mean   :0.6323  
##  3rd Qu.:0.857   3rd Qu.: 8.000   3rd Qu.: -5.142   3rd Qu.:1.0000  
##  Max.   :1.000   Max.   :11.000   Max.   :  4.532   Max.   :1.0000  
##   speechiness      acousticness    instrumentalness       liveness     
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000000   Min.   :0.0000  
##  1st Qu.:0.0361   1st Qu.:0.0159   1st Qu.:0.0000000   1st Qu.:0.0985  
##  Median :0.0491   Median :0.1900   Median :0.0000886   Median :0.1330  
##  Mean   :0.0890   Mean   :0.3297   Mean   :0.1847155   Mean   :0.2197  
##  3rd Qu.:0.0870   3rd Qu.:0.6290   3rd Qu.:0.1530000   3rd Qu.:0.2830  
##  Max.   :0.9650   Max.   :0.9960   Max.   :1.0000000   Max.   :1.0000  
##     valence           tempo       time_signature       pop        
##  Min.   :0.0000   Min.   :  0.0   Min.   :0.000   Min.   :0.0000  
##  1st Qu.:0.2410   1st Qu.: 99.4   1st Qu.:4.000   1st Qu.:0.0000  
##  Median :0.4490   Median :122.0   Median :4.000   Median :0.0000  
##  Mean   :0.4633   Mean   :122.1   Mean   :3.897   Mean   :0.1571  
##  3rd Qu.:0.6760   3rd Qu.:140.1   3rd Qu.:4.000   3rd Qu.:0.0000  
##  Max.   :0.9950   Max.   :243.4   Max.   :5.000   Max.   :1.0000  
##       rock           country             jazz           electronic   
##  Min.   :0.0000   Min.   :0.00000   Min.   :0.00000   Min.   :0.000  
##  1st Qu.:0.0000   1st Qu.:0.00000   1st Qu.:0.00000   1st Qu.:0.000  
##  Median :0.0000   Median :0.00000   Median :0.00000   Median :0.000  
##  Mean   :0.1967   Mean   :0.06841   Mean   :0.08189   Mean   :0.252  
##  3rd Qu.:0.0000   3rd Qu.:0.00000   3rd Qu.:0.00000   3rd Qu.:1.000  
##  Max.   :1.0000   Max.   :1.00000   Max.   :1.00000   Max.   :1.000  
##    classical         world            kids             other       
##  Min.   :0.000   Min.   :0.000   Min.   :0.00000   Min.   :0.0000  
##  1st Qu.:0.000   1st Qu.:0.000   1st Qu.:0.00000   1st Qu.:0.0000  
##  Median :0.000   Median :0.000   Median :0.00000   Median :0.0000  
##  Mean   :0.065   Mean   :0.304   Mean   :0.03445   Mean   :0.1343  
##  3rd Qu.:0.000   3rd Qu.:1.000   3rd Qu.:0.00000   3rd Qu.:0.0000  
##  Max.   :1.000   Max.   :1.000   Max.   :1.00000   Max.   :1.0000  
##       rap            two_genre       three_genre        four_genre      
##  Min.   :0.00000   Min.   :0.0000   Min.   :0.00000   Min.   :0.000000  
##  1st Qu.:0.00000   1st Qu.:0.0000   1st Qu.:0.00000   1st Qu.:0.000000  
##  Median :0.00000   Median :0.0000   Median :0.00000   Median :0.000000  
##  Mean   :0.03698   Mean   :0.2338   Mean   :0.02991   Mean   :0.009724  
##  3rd Qu.:0.00000   3rd Qu.:0.0000   3rd Qu.:0.00000   3rd Qu.:0.000000  
##  Max.   :1.00000   Max.   :1.0000   Max.   :1.00000   Max.   :1.000000  
##    five_genre      
##  Min.   :0.000000  
##  1st Qu.:0.000000  
##  Median :0.000000  
##  Mean   :0.001672  
##  3rd Qu.:0.000000  
##  Max.   :1.000000
col_names <- names(final_data)
for (i in seq_along(col_names)){
  hist(final_data[,i], main=paste("Histogram of", col_names[[i]]))
}

Correlation plot

final_data_cor1 <- cor(final_data)
corrplot(final_data_cor1, method="square", col = rev(colorRampPalette(c("#B40F20", "#FFFFFF", "#2E3A87"))(100)), type="lower", tl.col="black", tl.srt=60, tl.cex = 0.6)

ggpairs plot

The features shown were selected based on their high absolute correlations in the correlation plot.

ggpairs(final_data, columns = c("popularity", "danceability", "loudness", "instrumentalness"), lower = list(continuous = "smooth"), upper = list(continuous = "cor"))

Standardised metrics

Popularity vs: danceability, energy, speechiness, acousticness, instrumentalness, liveness, valence

basic_plots <- function(x){
  # scatter plot with point transparency to show overplotting
  plot_nt <- ggplot(final_data, aes(x = !!sym(x), y = popularity)) +
    geom_point(alpha = 0.1)
  # 2D binned count plot (heatmap of point density)
  plot_wt <- ggplot(final_data, aes(x = !!sym(x), y = popularity)) +
    geom_bin2d(alpha = 0.7) +
    scale_fill_gradientn(colors = c("#440154", "#30678D", "#35B778", "#FDE724", "#FFFFFF"))
  # return both plots
  return(list(plot_nt, plot_wt))
}

metrics <- c('danceability', 'energy', 'speechiness', 'acousticness', 'instrumentalness', 'liveness', 'valence')
for (i in metrics) {
  plots <- basic_plots(i)
  grid.arrange(plots[[1]], plots[[2]], ncol = 2)
}

Genre

# Assign the genre name based on the dummy variables
get_genre_name <- function(x) {
  ifelse(x["two_genre"] == 1, "2_genres",
    ifelse(x["three_genre"] == 1, "3_genres",
      ifelse(x["four_genre"] == 1, "4_genres",
        ifelse(x["five_genre"] == 1, "5_genres",
          ifelse(x["rock"] == 1, "rock",
            ifelse(x["country"] == 1, "country",
              ifelse(x["jazz"] == 1, "jazz",
                ifelse(x["electronic"] == 1, "electronic",
                  ifelse(x["classical"] == 1, "classical",
                    ifelse(x["world"] == 1, "world",
                      ifelse(x["kids"] == 1, "kids",
                        ifelse(x["other"] == 1, "other",
                          ifelse(x["rap"] == 1, "rap", "pop")))))))))))))
}
# Apply the function to each row of the data frame and create a new column with the genre names
temp_data <- data.frame(final_data)
temp_data$genre_name <- apply(final_data[, -1], 1, get_genre_name)

# Create a bar plot of mean popularity by genre
mean_popularity <- tapply(temp_data$popularity, temp_data$genre_name, mean)
barplot(mean_popularity, xlab = "Genre", ylab = "Mean Popularity", col = "steelblue", main = "Mean Popularity by Genre", las = 2, cex.names = 0.8)
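As a side note, the per-row apply() with nested ifelse() above could be replaced by a vectorised sketch that assigns labels in the same priority order. The helper name `genre_label` is hypothetical and not used elsewhere in the analysis:

```r
# Vectorised alternative to the nested ifelse(): walk the dummy columns
# from lowest to highest priority, so the highest-priority dummy that
# equals 1 wins; rows with no dummy set default to "pop".
priority <- c("two_genre", "three_genre", "four_genre", "five_genre",
              "rock", "country", "jazz", "electronic", "classical",
              "world", "kids", "other", "rap")
labels <- c("2_genres", "3_genres", "4_genres", "5_genres",
            "rock", "country", "jazz", "electronic", "classical",
            "world", "kids", "other", "rap")

genre_label <- function(df) {
  out <- rep("pop", nrow(df))            # default when no dummy fires
  for (j in rev(seq_along(priority))) {  # low priority first, high last
    out[df[[priority[j]]] == 1] <- labels[j]
  }
  out
}
```

This avoids a per-row function call, which matters on a dataset of this size.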

Modelling

Set up the Tidymodels Framework

# Define X, y, data
X <- select(final_data, -1)
y <- final_data$popularity
data <- data.frame(y = y, X = X)

# Split data into training and test set
data_split <- initial_split(data)
data_train <- training(data_split)
data_test <- testing(data_split)

# Cross-validation for tuning the parameters
data_cv <- vfold_cv(data_train, v = 10)

# Pre-process the model (no transformation steps: the recipe just records the formula)
data_recipe <- data_train %>%
  recipe(y ~ .) %>%
  prep()

Baseline model

A simple baseline for comparison against the more sophisticated models; here we have chosen linear regression.

baseline <- lm(y ~ X.explicit + X.danceability + X.instrumentalness, data = data_train)
predictions_baseline <- predict(baseline, newdata = data_test)

# Test metrics ------------------
RMSE_baseline <- sqrt(mean((data_test$y - predictions_baseline)^2))
RSQ_baseline <- cor(data_test$y, predictions_baseline)^2

# Print the value
print("Testing: ")
## [1] "Testing: "
cat("RMSE:", RMSE_baseline, "\n")
## RMSE: 18.89782
cat("R-squared:", RSQ_baseline, "\n")
## R-squared: 0.03682725
# Training metrics

# Get summary statistics ------------------
summary_stats <- summary(baseline)
# Extract RMSE and R-squared values
RMSE_baseline_train <- sqrt(mean(summary_stats$residuals^2))
RSQ_baseline_train <- summary_stats$r.squared

print("Training: ")
## [1] "Training: "
cat("RMSE:", RMSE_baseline_train, "\n")
## RMSE: 19.00617
cat("R-squared:", RSQ_baseline_train, "\n")
## R-squared: 0.03649003
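Since the same two formulas are computed for both the test and the training set, they could be bundled in a small helper. This is a sketch; `eval_metrics` is a hypothetical name, and calling it as `eval_metrics(data_test$y, predictions_baseline)` would reproduce the test-set numbers above:

```r
# Hypothetical helper bundling the RMSE and R-squared formulas used above
eval_metrics <- function(actual, predicted) {
  c(rmse = sqrt(mean((actual - predicted)^2)),   # root mean squared error
    rsq  = cor(actual, predicted)^2)             # squared correlation
}

# small worked example with toy numbers
m <- eval_metrics(c(1, 2, 3, 4), c(1, 2, 3, 5))
```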

Lasso / Ridge / Elastic-net

A non-baseline model that is (relatively) interpretable.

  • Defines a linear regression model with elastic-net regularisation (which nests lasso and ridge) using the linear_reg() function from the parsnip package
  • tune() is used to mark the hyperparameters penalty (P) and mixture (M) for tuning
  • set_engine() specifies the modelling engine used to fit the model (here glmnet)
  • the resulting object pen_reg_y is a model specification that can then be used for training, tuning and prediction
# Model specification = penalised linear regression
pen_reg_y <- linear_reg(penalty = tune('P'), mixture = tune('M')) %>%
  set_engine('glmnet')

# Set up the workflow
pen_reg_wf <- workflow() %>%
  add_recipe(data_recipe) %>%
  add_model(pen_reg_y)

# Tune the parameters
fit_pen_reg <- tune_grid(pen_reg_wf,
                         #grid = data.frame(P = 2^seq(-3, 2, by = 1),
                                           #M = seq(0, 1, by = 0.2)),
                         data_cv,
                         metrics = metric_set(rmse, mae, rsq),
                         control = control_grid(save_pred = TRUE))
fit_pen_reg %>% autoplot()  # plot the result for each value of the parameters

# Select the best model with the smallest cross-validation rmse
pen_reg_best <- fit_pen_reg %>%
  select_best(metric = 'rmse') 
pen_reg_best   # print the best model
## # A tibble: 1 × 3
##        P     M .config              
##    <dbl> <dbl> <chr>                
## 1 0.0199 0.710 Preprocessor1_Model07
### With the best parameters selected, we can now finalise and refit the model

# Fit the final model
pen_reg_final <- finalize_model(pen_reg_y, pen_reg_best)

# Predict on the test data with the final model
pen_reg_test <- pen_reg_wf %>%
  update_model(pen_reg_final) %>%
  last_fit(split = data_split) %>%
  collect_metrics()
pen_reg_test  # print the result
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard      18.1   Preprocessor1_Model1
## 2 rsq     standard       0.115 Preprocessor1_Model1
# Extract the tuned values as scalars rather than one-column tibbles
P_best <- pen_reg_best$P
M_best <- pen_reg_best$M

# glmnet() expects a numeric matrix, so convert the predictor data frame
glmnet_best <- glmnet(as.matrix(select(data_train, -1)), data_train$y,
                    family = "gaussian",
                    alpha = M_best)
glmnet_lasso <- glmnet(as.matrix(select(data_train, -1)), data_train$y,
                    family = "gaussian",
                    alpha = 1)
glmnet_ridge <- glmnet(as.matrix(select(data_train, -1)), data_train$y,
                    family = "gaussian",
                    alpha = 0)

plot(glmnet_best, xvar = "lambda")

Interpretation:

Comparison to the baseline model: predictive accuracy is better, as seen in the lower test RMSE (18.1 vs 18.9). The R-squared has also improved, from 0.03 to 0.11. Despite this, it is still very low, which gives us reason to think the relationship may not be linear. This motivates one of the more complex, non-linear models that follow: the random forest.

[PARK] Gradient descent (minibatch)

[JENNY] Random forest

Evaluation